from __future__ import print_function
import os
import sys
import math
import pickle
import pdb
import argparse
import random
from tqdm import tqdm
from shutil import copy
import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import scipy.io
from scipy.linalg import qr 
import igraph
from random import shuffle
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from util import *
from model import *


# Command-line interface. Defaults reproduce the original experiment setup.
parser = argparse.ArgumentParser(
    description='Train Variational Autoencoders for DAGs')
parser.add_argument(
    '--nvt', type=int, default=67, help='number of different node types')
parser.add_argument(
    '--data-name', default='final_structures6', help='graph dataset name')
parser.add_argument('--save-appendix', default='',
    help='what to append to data-name as save-name for results')
parser.add_argument(
    '--save-interval', type=int, default=100, metavar='N',
    help='how many epochs to wait each time to save model states')
parser.add_argument('--continue-from', type=int, default=None,
    help="from which epoch's checkpoint to continue training")
parser.add_argument('--hs', type=int, default=501, metavar='N',
    help='hidden size of GRUs')
parser.add_argument('--nz', type=int, default=56, metavar='N',
    help='number of dimensions of latent vectors z')
parser.add_argument('--max-node-data', type=int, default=50,
    help='max_n_data value passed to the graph data loader '
         '(also part of the results directory name)')
parser.add_argument('--bidirectional', action='store_true', default=False,
    help='whether to use bidirectional encoding')
# Implicit string concatenation (instead of a backslash inside the literal)
# keeps the help text free of embedded indentation whitespace.
parser.add_argument('--predictor', action='store_true', default=False,
    help='whether to train a performance predictor from latent '
         'encodings and a VAE at the same time')
parser.add_argument('--sample-number', type=int, default=20, metavar='N',
    help='how many samples to generate each time')

parser.add_argument('--model', default='DVAE',
                    help='model to use: DVAE, SVAE, DVAE_fast, DVAE_BN, '
                         'SVAE_oneshot, DVAE_GCN')
parser.add_argument('--lr', type=float, default=1e-4, metavar='LR',
                    help='learning rate (default: 1e-4)')
parser.add_argument('--epochs', type=int, default=100000, metavar='N',
                    help='number of epochs to train')
parser.add_argument('--batch-size', type=int, default=1, metavar='N',
                    help='batch size during training')
parser.add_argument('--infer-batch-size', type=int, default=1, metavar='N',
                    help='batch size during inference')
parser.add_argument('--seed', type=int, default=6666, metavar='S',
                    help='random seed (default: 6666)')




args = parser.parse_args()
# Seed every RNG used by the script (torch CPU/CUDA, numpy, stdlib random).
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
# The script was pinned to GPU index 2. Keep that choice when such a device
# exists, but fall back instead of failing at `model.to(device)` on machines
# with fewer GPUs or without CUDA.
# NOTE(review): helpers in model/util may pick their own device; confirm they
# follow this one.
if torch.cuda.is_available() and torch.cuda.device_count() > 2:
    device = torch.device("cuda:2")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
np.random.seed(args.seed)
random.seed(args.seed)
print(args)

# Results live under <script dir>/results/. The original passed the *string*
# '__file__' to realpath, which resolves against the current working
# directory; that is kept only as a fallback when __file__ is undefined
# (e.g. interactive sessions).
try:
    args.file_dir = os.path.dirname(os.path.realpath(__file__))
except NameError:
    args.file_dir = os.path.dirname(os.path.realpath('__file__'))
args.res_dir = os.path.join(args.file_dir,
    'results/{}{}_pruned_manual_graph_temp_max_n_{}_lr_{}_seed_{}_predictor_{}_hs_{}'.format(args.data_name,
    args.save_appendix, args.max_node_data, args.lr, args.seed, args.predictor, args.hs))
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(args.res_dir, exist_ok=True)

# Load the train/test splits plus the graph metadata (graph_args) that is used
# below to size the model.
train_data, test_data, graph_args = load_ENAS_graphs_with_type_label(
    max_n_data=args.max_node_data, n_types=args.nvt)

# Append the exact invocation to cmd_input.txt in the results directory so
# that every run (including resumed ones) is recorded.
cmd_input = 'python {}\n'.format(' '.join(sys.argv))
cmd_log_path = os.path.join(args.res_dir, 'cmd_input.txt')
with open(cmd_log_path, 'a') as cmd_log:
    cmd_log.write(cmd_input)
print('Command line input: ' + cmd_input + ' is saved.')


# Resolve the model class by name among this module's globals (the classes are
# brought in by `from model import *`). Unlike eval(args.model), this only
# accepts an existing name, not an arbitrary expression typed on the CLI.
try:
    model_cls = globals()[args.model]
except KeyError:
    raise ValueError(
        'Unknown --model {!r}; expected the name of a model class'.format(
            args.model)) from None
model = model_cls(
    graph_args.max_n, graph_args.num_vertex_type, graph_args.START_TYPE,
    graph_args.END_TYPE, hs=args.hs, nz=args.nz, bidirectional=args.bidirectional)

if args.predictor:
    # 3-logit head on the latent mean, trained with cross-entropy (see train()).
    predictor = nn.Linear(args.nz, 3)
    model.predictor = predictor
    model.celoss = nn.CrossEntropyLoss()

# Move the model to its device *before* building the optimizer, as the
# torch.optim documentation recommends.
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = ReduceLROnPlateau(
    optimizer, 'min', factor=0.1, patience=10, verbose=True)

if args.continue_from is not None:
    # Restore model, optimizer and scheduler state saved by the epoch loop.
    epoch = args.continue_from
    load_module_state(model, os.path.join(args.res_dir,
        'model_checkpoint{}.pth'.format(epoch)))
    load_module_state(optimizer, os.path.join(args.res_dir,
        'optimizer_checkpoint{}.pth'.format(epoch)))
    load_module_state(scheduler, os.path.join(args.res_dir,
        'scheduler_checkpoint{}.pth'.format(epoch)))



def train(epoch):
    """Run one training epoch over ``train_data`` (shuffled in place).

    Graphs are grouped into mini-batches of ``args.batch_size``; the last batch
    may be smaller. Each batch minimises ``model.loss``. When ``args.predictor``
    is set, the cross-entropy between ``model.predictor(mu)`` and the integer
    labels is added to it.

    Args:
        epoch: epoch number, used only for progress/log messages.

    Returns:
        ``(train_loss, recon_loss, kld_loss)`` summed over all batches, or
        ``(train_loss, recon_loss, kld_loss, pred_loss)`` when
        ``args.predictor`` is set.
    """
    model.train()
    train_loss = 0.
    recon_loss = 0.
    kld_loss = 0.
    pred_loss = 0.
    shuffle(train_data)
    pbar = tqdm(train_data)
    g_batch = []
    y_batch = []
    last_index = len(train_data) - 1
    for i, (g, y) in enumerate(pbar):
        g_batch.append(g)
        y_batch.append(y)
        # Keep accumulating until the batch is full or the data is exhausted.
        if len(g_batch) < args.batch_size and i != last_index:
            continue
        optimizer.zero_grad()
        g_batch = model._collate_fn(g_batch)
        n = len(g_batch)
        mu, logvar = model.encode(g_batch)
        loss, recon, kld = model.loss(mu, logvar, g_batch)
        if args.predictor:
            labels = torch.LongTensor(y_batch).view(-1).to(device)
            pred = model.celoss(model.predictor(mu), labels)
            # Out-of-place add: never mutates a tensor autograd may still need.
            loss = loss + pred
            pbar.set_description(
                'Epoch: %d, loss: %0.4f, recon: %0.4f, kld: %0.4f, pred: %0.4f'\
                % (epoch, loss.item()/n, recon.item()/n,
                kld.item()/n, pred.item()/n))
        else:
            pbar.set_description(
                'Epoch: %d, loss: %0.4f, recon: %0.4f, kld: %0.4f' % (
                epoch, loss.item()/n, recon.item()/n,
                kld.item()/n))
        loss.backward()

        train_loss += loss.item()
        recon_loss += recon.item()
        kld_loss += kld.item()
        if args.predictor:
            pred_loss += pred.item()
        optimizer.step()
        g_batch = []
        y_batch = []

    # max(..., 1) keeps an empty training set from raising ZeroDivisionError.
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / max(len(train_data), 1)))

    if args.predictor:
        return train_loss, recon_loss, kld_loss, pred_loss
    return train_loss, recon_loss, kld_loss


def test():
    """Evaluate reconstruction (and, optionally, predictor) quality on ``test_data``.

    For each batch, the reconstruction NLL from ``model.loss`` is accumulated.
    Each encoding is then sampled ``encode_times`` times and each sample is
    decoded ``decode_times`` times. A reconstruction counts as perfect when
    ``is_same_DAG`` says it matches its input.

    With ``args.predictor``, the predictor is scored exactly as in ``train()``:
    cross-entropy (``model.celoss``) of ``model.predictor(mu)`` against
    integer labels. The previous version called a nonexistent
    ``model.mseloss`` with float targets.

    Returns:
        ``(Nll, acc)``, or ``(Nll, acc, pred_ce)`` with ``args.predictor``,
        where ``pred_ce`` is the mean per-graph cross-entropy.
    """
    model.eval()
    encode_times = 10
    decode_times = 10
    Nll = 0
    pred_loss = 0
    n_perfect = 0
    print('Testing begins...')
    pbar = tqdm(test_data)
    g_batch = []
    y_batch = []
    last_index = len(test_data) - 1
    # Pure evaluation: no autograd graphs are needed.
    with torch.no_grad():
        for i, (g, y) in enumerate(pbar):
            g_batch.append(g)
            y_batch.append(y)
            if len(g_batch) < args.infer_batch_size and i != last_index:
                continue
            g = model._collate_fn(g_batch)
            mu, logvar = model.encode(g)
            _, nll, _ = model.loss(mu, logvar, g)
            pbar.set_description('nll: {:.4f}'.format(nll.item()/len(g_batch)))
            Nll += nll.item()
            if args.predictor:
                labels = torch.LongTensor(y_batch).view(-1).to(device)
                pred = model.celoss(model.predictor(mu), labels)
                # celoss (default 'mean' reduction) averages over the batch;
                # rescale so the final division yields a per-graph mean.
                pred_loss += pred.item() * len(g_batch)
            for _ in range(encode_times):
                z = model.reparameterize(mu, logvar)
                for _ in range(decode_times):
                    g_recon = model.decode(z)
                    n_perfect += sum(is_same_DAG(g0, g1) for g0, g1 in zip(g, g_recon))
            g_batch = []
            y_batch = []
    # max(..., 1) keeps an empty test set from raising ZeroDivisionError.
    n_test = max(len(test_data), 1)
    Nll /= n_test
    pred_loss /= n_test
    acc = n_perfect / (n_test * encode_times * decode_times)
    if args.predictor:
        print('Test average recon loss: {0}, recon accuracy: {1:.4f}, pred CE: {2:.4f}'.format(
            Nll, acc, pred_loss))
        return Nll, acc, pred_loss
    else:
        print('Test average recon loss: {0}, recon accuracy: {1:.4f}'.format(Nll, acc))
        return Nll, acc


def extract_latent(data):
    """Encode every ``(graph, label)`` pair in ``data`` with the model.

    Graphs are encoded in batches of ``args.infer_batch_size`` (the last batch
    may be smaller). Only the mean returned by ``model.encode`` is kept.

    Returns:
        ``(Z, Y)``: ``Z`` concatenates the per-batch means along axis 0, in data
        order; ``Y`` is ``np.array`` of the labels in the same order.
    """
    model.eval()
    Z = []
    Y = []
    g_batch = []
    last_index = len(data) - 1
    # Inference only: skip building autograd graphs (saves memory and time).
    with torch.no_grad():
        for i, (g, y) in enumerate(tqdm(data)):
            # Copy the igraph; otherwise the original graphs would keep the
            # hidden states the encoder stores on them and consume more GPU memory.
            g_batch.append(g.copy())
            if len(g_batch) == args.infer_batch_size or i == last_index:
                g_batch = model._collate_fn(g_batch)
                mu, _ = model.encode(g_batch)
                Z.append(mu.cpu().numpy())
                g_batch = []
            Y.append(y)
    return np.concatenate(Z, 0), np.array(Y)


# Extract latent representations Z.
def save_latent_representations(epoch):
    """Encode both data splits and persist the latents for ``epoch``.

    ``(Z_train, Y_train, Z_test, Y_test)`` is written as a pickle and as a
    MATLAB ``.mat`` file in ``args.res_dir``. Both share the stem
    ``<data_name>_latent_epoch<epoch>``.
    """
    Z_train, Y_train = extract_latent(train_data)
    Z_test, Y_test = extract_latent(test_data)
    stem = os.path.join(args.res_dir,
                        args.data_name + '_latent_epoch{}'.format(epoch))
    latent_pkl_name = stem + '.pkl'
    latent_mat_name = stem + '.mat'

    payload = (Z_train, Y_train, Z_test, Y_test)
    with open(latent_pkl_name, 'wb') as pkl_file:
        pickle.dump(payload, pkl_file)
    print('Saved latent representations to ' + latent_pkl_name)

    mat_contents = dict(Z_train=Z_train, Z_test=Z_test,
                        Y_train=Y_train, Y_test=Y_test)
    scipy.io.savemat(latent_mat_name, mdict=mat_contents)


# Training begins here.
min_loss = math.inf  # math.inf needs Python >= 3.5
min_loss_epoch = None
loss_name, loss_plot_name, test_results_name = (
    os.path.join(args.res_dir, fname)
    for fname in ('train_loss.txt', 'train_loss_plot.pdf', 'test_results.txt'))
# Start with an empty per-epoch loss log.
# NOTE(review): this also discards the log when resuming via --continue-from.
try:
    os.remove(loss_name)
except FileNotFoundError:
    pass



# Main loop: resume after the checkpoint epoch when --continue-from is given.
start_epoch = args.continue_from if args.continue_from is not None else 0
for epoch in range(start_epoch + 1, args.epochs + 1):
    if args.predictor:
        train_loss, recon_loss, kld_loss, pred_loss = train(epoch)
    else:
        train_loss, recon_loss, kld_loss = train(epoch)
        pred_loss = 0.0
    # One line per epoch: total, recon, KLD and predictor loss, each divided by
    # the number of training graphs.
    with open(loss_name, 'a') as loss_file:
        loss_file.write("{:.2f} {:.2f} {:.2f} {:.2f}\n".format(
            train_loss/len(train_data),
            recon_loss/len(train_data),
            kld_loss/len(train_data),
            pred_loss/len(train_data),
            ))
    scheduler.step(train_loss)
    if epoch % args.save_interval != 0:
        continue

    print("save current model...")
    model_name = os.path.join(
        args.res_dir, 'model_checkpoint{}.pth'.format(epoch))
    optimizer_name = os.path.join(
        args.res_dir, 'optimizer_checkpoint{}.pth'.format(epoch))
    scheduler_name = os.path.join(
        args.res_dir, 'scheduler_checkpoint{}.pth'.format(epoch))
    torch.save(model.state_dict(), model_name)
    torch.save(optimizer.state_dict(), optimizer_name)
    torch.save(scheduler.state_dict(), scheduler_name)
    print("extract latent representations...")
    save_latent_representations(epoch)
    print("sample from prior...")
    # Samples are still drawn (so RNG consumption is unchanged) but are not
    # plotted; the old plotting loop was commented out and referenced a
    # nonexistent args.data_type.
    sampled = model.generate_sample(args.sample_number)
    print("plot train loss...")
    losses = np.loadtxt(loss_name)
    if losses.ndim == 1:
        # A log with a single epoch loads as 1-D; nothing to plot yet.
        continue
    fig = plt.figure()
    epoch_axis = range(1, losses.shape[0] + 1)
    for column, label in enumerate(('Total', 'Recon', 'KLD')):
        plt.plot(epoch_axis, losses[:, column], label=label)
    plt.xlabel('Epoch')
    plt.ylabel('Train loss')
    plt.legend()
    plt.savefig(loss_plot_name)
    # Close the figure, otherwise one figure leaks per save interval.
    plt.close(fig)














